package link.jfire.sql.function.impl;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import link.jfire.baseutil.StringUtil;
import link.jfire.baseutil.collection.StringCache;
import link.jfire.baseutil.collection.set.LightSet;
import link.jfire.baseutil.reflect.ReflectUtil;
import link.jfire.baseutil.simplelog.ConsoleLogFactory;
import link.jfire.baseutil.simplelog.Logger;
import link.jfire.baseutil.verify.Verify;
import link.jfire.sql.annotation.Column;
import link.jfire.sql.annotation.Id;
import link.jfire.sql.annotation.SqlIgnore;
import link.jfire.sql.annotation.TableEntity;
import link.jfire.sql.field.MapField;
import link.jfire.sql.field.impl.IntegerField;
import link.jfire.sql.field.impl.StringField;
import link.jfire.sql.field.impl.WLongField;
import link.jfire.sql.function.DAOBean;
import link.jfire.sql.function.LockMode;
import link.jfire.sql.util.DaoFactory;
import sun.misc.Unsafe;

@SuppressWarnings("restriction")
public class DAOBeanImpl implements DAOBean
{
    private static final Logger          logger              = ConsoleLogFactory.getLogger();
    private final Class<?>               entityClass;
    // Table name taken from the entity's @TableEntity annotation (resolved in buildSql()).
    private String                       tableName;
    private String                       deleteSql;
    // Prefix "delete from <table> where <id> in (" that deleteByIds completes with placeholders.
    private String                       batchDeleteSql;
    private String                       getSql;
    private String                       getSqlInShare;
    private String                       getSqlForUpdate;
    private String                       saveSql;
    private String                       updateSql;
    // MapField of the property annotated with @Id (the primary-key column).
    private MapField                     idField;
    // Every mapped property; read back by the full-row get statements.
    private MapField[]                   getFields;
    // Mapped properties excluding those whose MapField reports saveIgnore(); bound by insert/update.
    private MapField[]                   insertOrUpdateFields;
    private int                          insertOrUpdateFieldNum;
    // Unsafe offset of the id property, used to read/write the id value directly.
    private long                         idOffset;
    private final Unsafe                 unsafe              = ReflectUtil.getUnsafe();
    // Selective get: requested property list -> generated select statement + result fields.
    private final Map<String, SqlPlan>   selectGetPlanMap    = new ConcurrentHashMap<>();
    // getList: "<param names>-<result names>" -> select statement + param fields + result fields.
    private final Map<String, SqlPlan>   listGetPlanMap      = new ConcurrentHashMap<>();
    // Selective update by id: updated property list -> update statement + bound fields.
    private final Map<String, SqlPlan>   selectUpdatePlanMap = new ConcurrentHashMap<>();
    // Conditional update: "<update names>-<condition names>" -> update statement + bound fields.
    private final Map<String, SqlPlan>   paramUpdatePlanMap  = new ConcurrentHashMap<>();
    // Property name -> MapField. Only insertOrUpdateFields are registered, so the selective
    // get/list/update methods accept only those property names.
    private final Map<String, MapField>  fieldMap            = new HashMap<>();
    
    /**
     * Immutable pairing of a generated SQL string with the MapFields it binds and reads. Caching a
     * single object per key (rather than separate sql/field maps) guarantees a concurrent reader
     * never observes the SQL without its fields, or vice versa.
     */
    private static final class SqlPlan
    {
        // The generated statement text.
        final String     sql;
        // Fields bound, in order, to the statement's leading '?' placeholders (null when unused).
        final MapField[] paramFields;
        // Fields populated from each result row (null for update statements).
        final MapField[] resultFields;
        
        SqlPlan(String sql, MapField[] paramFields, MapField[] resultFields)
        {
            this.sql = sql;
            this.paramFields = paramFields;
            this.resultFields = resultFields;
        }
    }
    
    /**
     * Scans the entity class, builds a MapField for every mapped property and pre-generates the
     * full-row insert/update/delete/get statements.
     *
     * @param entityClass entity type; must carry @TableEntity and have a property annotated with @Id
     */
    public DAOBeanImpl(Class<?> entityClass)
    {
        this.entityClass = entityClass;
        Field[] fields = ReflectUtil.getAllFields(entityClass);
        LightSet<MapField> set = new LightSet<>();
        for (Field each : fields)
        {
            if (isNotMapped(each))
            {
                continue;
            }
            set.add(DaoFactory.buildMapField(each));
            if (each.isAnnotationPresent(Id.class))
            {
                idField = DaoFactory.buildMapField(each);
                idOffset = unsafe.objectFieldOffset(each);
            }
        }
        Verify.notNull(idField, "使用TableEntity映射的表必须由id字段，请检查{}", entityClass.getName());
        LightSet<MapField> tmp = new LightSet<>();
        for (MapField each : set)
        {
            if (each.saveIgnore())
            {
                continue;
            }
            tmp.add(each);
        }
        insertOrUpdateFields = tmp.toArray(MapField.class);
        insertOrUpdateFieldNum = insertOrUpdateFields.length;
        for (MapField each : insertOrUpdateFields)
        {
            fieldMap.put(each.getFieldName(), each);
        }
        getFields = set.toArray(MapField.class);
        buildSql();
    }
    
    /**
     * Returns true for fields that are never mapped to a column: @SqlIgnore, Map/List/interface/array
     * typed, static, or annotated with @Column(daoIgnore = true).
     */
    private static boolean isNotMapped(Field field)
    {
        Class<?> type = field.getType();
        if (field.isAnnotationPresent(SqlIgnore.class) || Map.class.isAssignableFrom(type) || List.class.isAssignableFrom(type) || type.isInterface() || type.isArray() || Modifier.isStatic(field.getModifiers()))
        {
            return true;
        }
        return field.isAnnotationPresent(Column.class) && field.getAnnotation(Column.class).daoIgnore();
    }
    
    /** Generates the fixed save/update/delete/get statements for the entity's table. */
    private void buildSql()
    {
        TableEntity tableEntity = entityClass.getAnnotation(TableEntity.class);
        Verify.notNull(tableEntity, "Entity class {} must be annotated with @TableEntity", entityClass.getName());
        tableName = tableEntity.name();
        StringCache cache = new StringCache();
        /******** insert sql *******/
        cache.append("insert into ").append(tableName).append(" ( ");
        for (MapField each : insertOrUpdateFields)
        {
            cache.append(each.getColName()).append(',');
        }
        cache.deleteLast().append(") values (");
        cache.appendStrsByComma("?", insertOrUpdateFields.length);
        cache.append(')');
        saveSql = cache.toString();
        /******** update-by-id sql *******/
        cache.clear();
        cache.append("update ").append(tableName).append(" set ");
        for (MapField each : insertOrUpdateFields)
        {
            cache.append(each.getColName()).append("=?,");
        }
        cache.deleteLast().append(" where ").append(idField.getColName()).append("=?");
        updateSql = cache.toString();
        /******** delete sql *****/
        cache.clear();
        cache.append("delete from ").append(tableName).append("  where ").append(idField.getColName()).append("=?");
        deleteSql = cache.toString();
        batchDeleteSql = "delete from " + tableName + " where " + idField.getColName() + " in (";
        /******** get-by-id sql (plain and with MySQL-style locking clauses) ******/
        cache.clear();
        cache.append("select ");
        for (MapField each : getFields)
        {
            cache.append(each.getColName()).append(",");
        }
        cache.deleteLast().append(" from ").append(tableName).append(" where ").append(idField.getColName()).append("=?");
        getSql = cache.toString();
        getSqlInShare = getSql + " LOCK IN SHARE MODE";
        getSqlForUpdate = getSql + " FOR UPDATE";
        logger.debug("为类：{}生成的\rsave  语句是: {},\rdelete语句是: {},\rupdate语句是: {},\rget   语句是: {}\r", entityClass.getName(), saveSql, deleteSql, updateSql, getSql);
    }
    
    /**
     * Looks up the MapField for an entity property name (surrounding whitespace ignored).
     *
     * @throws IllegalArgumentException if the name is not an insertable/updatable property
     */
    private MapField requireField(String fieldName)
    {
        MapField field = fieldMap.get(fieldName.trim());
        if (field == null)
        {
            throw new IllegalArgumentException(StringUtil.format("Unknown or non-persistent property '{}' on {}", fieldName, entityClass.getName()));
        }
        return field;
    }
    
    /**
     * Resolves a comma-separated list of property names and appends their columns to {@code cache}
     * as {@code col<columnSuffix><separator>col<columnSuffix>...}, adding each MapField to
     * {@code set} in list order.
     *
     * @throws IllegalArgumentException if the list holds no names or an unknown property name
     */
    private void appendFields(StringCache cache, LightSet<MapField> set, String fieldNames, String columnSuffix, String separator)
    {
        String[] names = fieldNames.split(",");
        if (names.length == 0)
        {
            // Without this guard the trailing deleteEnds() would cut into the preceding SQL keyword.
            throw new IllegalArgumentException(StringUtil.format("No property names given for {}", entityClass.getName()));
        }
        for (String name : names)
        {
            MapField field = requireField(name);
            set.add(field);
            cache.append(field.getColName()).append(columnSuffix).append(separator);
        }
        cache.deleteEnds(separator.length());
    }
    
    /** Instantiates the entity and fills it from the current row of {@code resultSet}. */
    private Object readEntity(ResultSet resultSet, MapField[] fields) throws SQLException, InstantiationException, IllegalAccessException
    {
        Object entity = entityClass.newInstance();
        for (MapField each : fields)
        {
            each.setEntityValue(entity, resultSet);
        }
        return entity;
    }
    
    /**
     * Runs a single-parameter (primary key) query and maps the first row.
     *
     * @return the mapped entity, or null when no row matches
     */
    @SuppressWarnings("unchecked")
    private <T> T queryOne(String sql, Object pk, MapField[] fields, Connection connection)
    {
        logger.trace("执行的sql是{}", sql);
        try (PreparedStatement pStat = connection.prepareStatement(sql))
        {
            pStat.setObject(1, pk);
            try (ResultSet resultSet = pStat.executeQuery())
            {
                if (resultSet.next())
                {
                    return (T) readEntity(resultSet, fields);
                }
                return null;
            }
        }
        catch (SQLException | InstantiationException | IllegalAccessException e)
        {
            throw new RuntimeException(e);
        }
    }
    
    /** Binds every insertOrUpdateField of {@code entity} to parameters 1..insertOrUpdateFieldNum. */
    private void bindInsertOrUpdateValues(PreparedStatement pStat, Object entity) throws SQLException
    {
        for (int i = 0; i < insertOrUpdateFieldNum; i++)
        {
            insertOrUpdateFields[i].setStatementValue(pStat, entity, i + 1);
        }
    }
    
    /**
     * Deletes the row whose primary key equals the entity's id value.
     *
     * @return always true; SQL failures are rethrown as RuntimeException
     */
    @Override
    public boolean delete(Object entity, Connection connection)
    {
        try (PreparedStatement pstat = connection.prepareStatement(deleteSql))
        {
            pstat.setObject(1, unsafe.getObject(entity, idOffset));
            pstat.executeUpdate();
            return true;
        }
        catch (SQLException e)
        {
            throw new RuntimeException(e);
        }
    }
    
    /** Loads every mapped property of the row with primary key {@code pk}; null if absent. */
    @Override
    public <T> T getById(Object pk, Connection connection)
    {
        return this.<T> queryOne(getSql, pk, getFields, connection);
    }
    
    /**
     * Loads only the given properties (plus the id) of the row with primary key {@code pk}.
     *
     * @param fieldNames comma-separated property names; must be insertable/updatable properties
     * @return the partially populated entity, or null if no row matches
     * @throws IllegalArgumentException for an empty list or an unknown property name
     */
    public <T> T getById(Object pk, Connection connection, String fieldNames)
    {
        SqlPlan plan = selectGetPlanMap.get(fieldNames);
        if (plan == null)
        {
            StringCache cache = new StringCache("select ");
            LightSet<MapField> set = new LightSet<>();
            appendFields(cache, set, fieldNames, "", ", ");
            // The primary key is always read back as well.
            cache.append(", ").append(idField.getColName());
            set.add(idField);
            cache.append(" from ").append(tableName).append(" where ").append(idField.getColName()).append(" = ?");
            plan = new SqlPlan(cache.toString(), null, set.toArray(MapField.class));
            selectGetPlanMap.put(fieldNames, plan);
        }
        return this.<T> queryOne(plan.sql, pk, plan.resultFields, connection);
    }
    
    /**
     * Selects rows whose {@code paramFieldNames} columns equal the corresponding values of
     * {@code entity}, populating only {@code resultFieldNames} in the returned entities.
     *
     * @param paramFieldNames  comma-separated property names used as AND-ed equality conditions
     * @param resultFieldNames comma-separated property names to read back
     * @return the matching entities; empty when none match
     * @throws IllegalArgumentException for an empty list or an unknown property name
     */
    @SuppressWarnings("unchecked")
    @Override
    public <T> List<T> getList(T entity, Connection connection, String paramFieldNames, String resultFieldNames)
    {
        String key = paramFieldNames + '-' + resultFieldNames;
        SqlPlan plan = listGetPlanMap.get(key);
        if (plan == null)
        {
            StringCache cache = new StringCache("select ");
            LightSet<MapField> resultSet = new LightSet<>();
            appendFields(cache, resultSet, resultFieldNames, "", ", ");
            cache.append(" from ").append(tableName).append(" where ");
            LightSet<MapField> paramSet = new LightSet<>();
            appendFields(cache, paramSet, paramFieldNames, "=?", " and ");
            plan = new SqlPlan(cache.toString(), paramSet.toArray(MapField.class), resultSet.toArray(MapField.class));
            listGetPlanMap.put(key, plan);
        }
        logger.trace("执行的sql语句是{}", plan.sql);
        try (PreparedStatement pStat = connection.prepareStatement(plan.sql))
        {
            int index = 1;
            for (MapField each : plan.paramFields)
            {
                each.setStatementValue(pStat, entity, index);
                index++;
            }
            try (ResultSet resultSet = pStat.executeQuery())
            {
                List<T> entitys = new ArrayList<>();
                while (resultSet.next())
                {
                    entitys.add((T) readEntity(resultSet, plan.resultFields));
                }
                return entitys;
            }
        }
        catch (SQLException | InstantiationException | IllegalAccessException e)
        {
            throw new RuntimeException(e);
        }
    }
    
    /**
     * Inserts the entity when its id is null (and writes the generated key back into it),
     * otherwise updates every insertable/updatable column of the row with that id.
     */
    @Override
    public <T> void save(T entity, Connection connection)
    {
        if (unsafe.getObject(entity, idOffset) == null)
        {
            insert(entity, connection);
            return;
        }
        try (PreparedStatement pStat = connection.prepareStatement(updateSql))
        {
            bindInsertOrUpdateValues(pStat, entity);
            idField.setStatementValue(pStat, entity, insertOrUpdateFieldNum + 1);
            pStat.executeUpdate();
        }
        catch (SQLException e)
        {
            throw new RuntimeException(e);
        }
    }
    
    /**
     * Inserts all entities using one JDBC batch. Generated keys are NOT written back.
     * An empty list is a no-op.
     */
    @Override
    public <T> void batchInsert(List<T> entitys, Connection connection)
    {
        if (entitys.isEmpty())
        {
            return;
        }
        try (PreparedStatement pStat = connection.prepareStatement(saveSql))
        {
            for (Object each : entitys)
            {
                bindInsertOrUpdateValues(pStat, each);
                pStat.addBatch();
            }
            pStat.executeBatch();
        }
        catch (SQLException e)
        {
            throw new RuntimeException(e);
        }
    }
    
    /**
     * Updates only the given properties of the row whose id equals the entity's id.
     *
     * @param fieldNames comma-separated property names to update
     * @return the JDBC update count
     * @throws IllegalArgumentException for an empty list or an unknown property name
     */
    @Override
    public <T> int update(T entity, Connection connection, String fieldNames)
    {
        SqlPlan plan = selectUpdatePlanMap.get(fieldNames);
        if (plan == null)
        {
            StringCache cache = new StringCache("update ");
            cache.append(tableName).append(" set ");
            LightSet<MapField> set = new LightSet<>();
            appendFields(cache, set, fieldNames, "=?", ", ");
            cache.append(" where ").append(idField.getColName()).append("=?");
            plan = new SqlPlan(cache.toString(), set.toArray(MapField.class), null);
            selectUpdatePlanMap.put(fieldNames, plan);
        }
        logger.trace("执行的sql语句是{}", plan.sql);
        try (PreparedStatement pStat = connection.prepareStatement(plan.sql))
        {
            int index = 1;
            for (MapField each : plan.paramFields)
            {
                each.setStatementValue(pStat, entity, index);
                index++;
            }
            idField.setStatementValue(pStat, entity, index);
            return pStat.executeUpdate();
        }
        catch (SQLException e)
        {
            throw new RuntimeException(e);
        }
    }
    
    /**
     * Sets the {@code updateFields} columns from the entity on every row whose {@code paramFields}
     * columns equal the entity's values (AND-ed equality conditions).
     *
     * @return the JDBC update count
     * @throws IllegalArgumentException for an empty list or an unknown property name
     */
    @Override
    public <T> int update(T entity, Connection connection, String updateFields, String paramFields)
    {
        String key = updateFields + '-' + paramFields;
        SqlPlan plan = paramUpdatePlanMap.get(key);
        if (plan == null)
        {
            StringCache cache = new StringCache("update ");
            cache.append(tableName).append(" set ");
            // Update fields and condition fields share one list: they are bound in that order.
            LightSet<MapField> set = new LightSet<>();
            appendFields(cache, set, updateFields, "=?", ", ");
            cache.append(" where ");
            appendFields(cache, set, paramFields, "=?", " and ");
            plan = new SqlPlan(cache.toString(), set.toArray(MapField.class), null);
            paramUpdatePlanMap.put(key, plan);
        }
        logger.trace("执行的sql语句是{}", plan.sql);
        try (PreparedStatement pStat = connection.prepareStatement(plan.sql))
        {
            int index = 1;
            for (MapField each : plan.paramFields)
            {
                each.setStatementValue(pStat, entity, index);
                index++;
            }
            return pStat.executeUpdate();
        }
        catch (SQLException e)
        {
            throw new RuntimeException(e);
        }
    }
    
    /**
     * Inserts the entity and writes the database-generated key into its id property.
     *
     * @throws RuntimeException if the id type is not Integer/Long/String (checked before any SQL
     *                          runs), if no generated key is returned, or on SQL failure
     */
    public <T> void insert(T entity, Connection connection)
    {
        // Validate up front so an unsupported id type never leaves an inserted row behind.
        if (!(idField instanceof IntegerField || idField instanceof StringField || idField instanceof WLongField))
        {
            throw new RuntimeException(StringUtil.format("id字段暂时支持Integer,Long,String.请检查{}", entity.getClass().getName()));
        }
        try (PreparedStatement pStat = connection.prepareStatement(saveSql, Statement.RETURN_GENERATED_KEYS))
        {
            bindInsertOrUpdateValues(pStat, entity);
            pStat.executeUpdate();
            try (ResultSet resultSet = pStat.getGeneratedKeys())
            {
                if (!resultSet.next())
                {
                    throw new IllegalStateException(StringUtil.format("No generated key was returned when inserting {}", entity.getClass().getName()));
                }
                unsafe.putObject(entity, idOffset, readGeneratedId(resultSet));
            }
        }
        catch (SQLException e)
        {
            throw new RuntimeException(e);
        }
    }
    
    /** Reads the generated key from the current row, typed to match the id property. */
    private Object readGeneratedId(ResultSet resultSet) throws SQLException
    {
        if (idField instanceof IntegerField)
        {
            return resultSet.getInt(1);
        }
        if (idField instanceof StringField)
        {
            return resultSet.getString(1);
        }
        // insert() has already restricted idField to the three supported types.
        return resultSet.getLong(1);
    }
    
    /**
     * Loads the full row with primary key {@code pk} using a locking read:
     * {@link LockMode#SHARE} appends "LOCK IN SHARE MODE", any other mode appends "FOR UPDATE".
     */
    @Override
    public <T> T getById(Object pk, Connection connection, LockMode mode)
    {
        String sql = mode == LockMode.SHARE ? getSqlInShare : getSqlForUpdate;
        return this.<T> queryOne(sql, pk, getFields, connection);
    }
    
    /**
     * Deletes rows whose id is in the comma-separated {@code ids} list (each token bound as a
     * String, untrimmed).
     *
     * @return the JDBC update count; 0 when the list contains no tokens
     */
    @Override
    public int deleteByIds(String ids, Connection connection)
    {
        List<Object> params = new ArrayList<>(16);
        for (String id : ids.split(","))
        {
            params.add(id);
        }
        return deleteByIdList(params, connection);
    }
    
    /**
     * Deletes rows whose id is in {@code ids}.
     *
     * @return the JDBC update count; 0 for an empty array
     */
    @Override
    public int deleteByIds(int[] ids, Connection connection)
    {
        List<Object> params = new ArrayList<>(ids.length);
        for (int id : ids)
        {
            params.add(id);
        }
        return deleteByIdList(params, connection);
    }
    
    /** Executes "delete ... where id in (?,...)" binding {@code params} in order. */
    private int deleteByIdList(List<Object> params, Connection connection)
    {
        if (params.isEmpty())
        {
            // "in ()" is not valid SQL, and there is nothing to delete anyway.
            return 0;
        }
        StringCache cache = new StringCache(batchDeleteSql);
        cache.appendStrsByComma("?", params.size());
        cache.append(')');
        String sql = cache.toString();
        try (PreparedStatement pStat = connection.prepareStatement(sql))
        {
            logger.trace("执行的sql是{}", sql);
            int index = 1;
            for (Object each : params)
            {
                pStat.setObject(index++, each);
            }
            return pStat.executeUpdate();
        }
        catch (SQLException e)
        {
            throw new RuntimeException(e);
        }
    }
    
}
